package net.dwade.livechat.websocket.client;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.simp.stomp.StompSession;

import com.alibaba.fastjson.JSONObject;

import net.dwade.livechat.websocket.beans.SubscribeType;
import net.dwade.livechat.websocket.beans.WebSocketMessageResponse;
import net.dwade.livechat.websocket.beans.WebSocketRequestMessage;

/**
 * A {@link ChatWebsocketClient} that measures the response latency of chat requests.
 * <p>
 * The post time of every outgoing request is remembered by websocket request id. When a
 * {@link SubscribeType#CHAT_RESPONSE} message carrying the same request id arrives, the elapsed
 * time is computed and a summary line is logged. The summary also includes counters that are
 * {@code static}, i.e. aggregated across all instances of this class.
 *
 * @author huangxf
 * @date 2017-01-04
 */
public class TimeLoggingChatWebsocketClient extends ChatWebsocketClient {
	
	private static final Logger logger = LoggerFactory.getLogger( TimeLoggingChatWebsocketClient.class );
	
	/**
	 * Post time (epoch millis, from {@code getPostTime().getTime()}) of each request still awaiting
	 * its response, keyed by websocket request id.
	 * NOTE(review): entries for requests that never receive a CHAT_RESPONSE are never evicted by
	 * this class — confirm whether that can grow unbounded in long runs.
	 */
	private final Map<String, Long> msgTimeMap = new ConcurrentHashMap<String, Long>( 128 );
	
	/**
	 * Log line layout: chatSessionId, requestId, latency of this response (ms), messages sent,
	 * responses received, average latency (ms), retCode.
	 */
	private final static String FORMAT = "%1$s %2$s %3$s %4$s %5$s %6$.4f %7$s";
	
	/** Sum of the latencies (ms) of all timed responses, shared by all instances. */
	private static final AtomicLong TIME_COUNT = new AtomicLong();
	
	/** Number of requests sent, shared by all instances. */
	private static final AtomicInteger SEND_COUNT = new AtomicInteger();
	
	/** Number of responses that were matched to a request and timed, shared by all instances. */
	private static final AtomicInteger MESSAGE_COUNT = new AtomicInteger();

	public TimeLoggingChatWebsocketClient(String userAgent, String indexUrl,
			String pullUrl, String chatUrl, String domain, String httpSessionId) {
		super( userAgent, indexUrl, pullUrl, chatUrl, domain, httpSessionId );
	}
	
	/**
	 * Counts the outgoing request and records its post time, then delegates to the superclass.
	 *
	 * @param message the request about to be sent
	 */
	@Override
	protected void beforeSendMessage(WebSocketRequestMessage message) {
		// Count every sent message.
		SEND_COUNT.incrementAndGet();
		msgTimeMap.put( message.getWebsocketRequestId(), message.getPostTime().getTime() );
		super.beforeSendMessage( message );
	}
	
	/**
	 * Times CHAT_RESPONSE messages against the recorded post time of their request and logs a
	 * summary line in {@link #FORMAT} layout. Other subscription types are ignored.
	 * Any failure while parsing or logging is caught and logged, never propagated.
	 * NOTE(review): the superclass implementation is not invoked here — confirm this is intended.
	 *
	 * @param session     the STOMP session the message arrived on (unused)
	 * @param type        the subscription the message belongs to
	 * @param textMessage the raw JSON payload
	 */
	@Override
	protected void subscribeCallback( StompSession session, SubscribeType type, String textMessage ) {
		if ( !SubscribeType.CHAT_RESPONSE.equals( type ) ) {
			return;
		}
		try {
			WebSocketMessageResponse resp = JSONObject.parseObject( textMessage, WebSocketMessageResponse.class );
			String requestId = resp.getWebsocketRequestId();
			// ConcurrentHashMap rejects null keys, and such a response was never timed anyway.
			// remove() atomically claims the entry, so a duplicate response is not counted twice.
			Long postTime = ( requestId == null ) ? null : msgTimeMap.remove( requestId );
			if ( postTime == null ) {
				return;
			}
			long costTime = System.currentTimeMillis() - postTime;
			long total = TIME_COUNT.addAndGet( costTime );
			int sendCount = SEND_COUNT.get();
			int msgCount = MESSAGE_COUNT.incrementAndGet();
			double avg = (double) total / msgCount;
			// Argument order follows the documented FORMAT layout: sent count before received count.
			String msg = String.format( FORMAT, this.getChatSessionId(), requestId,
					costTime, sendCount, msgCount, avg, resp.getRetCode() );
			logger.info( "WebsocketResponse {}", msg );
		} catch (Exception e) {
			logger.error( "Time logging error.", e );
		}
	}

}
